import numpy as np
import random 
import numba
from numba import jit, njit

from ..tinyhouse import InputOutput

# line_profiler's ``kernprof`` injects a ``profile`` decorator into builtins.
# When that is not present, fall back to a no-op decorator so the @profile
# annotations in this module are harmless.
try:
    profile
except NameError:
    def profile(x):
        """No-op stand-in for line_profiler's ``profile`` decorator."""
        return x

class State(object):
    """Current choice of one Sample (or None) for each chromosome.

    ``ebv`` caches the sum of ``chrom.ebv`` over all non-None chromosomes and
    is recomputed whenever a chromosome is replaced.
    """

    def __init__(self, nChromosomes, chromosomes = None):

        self.ebv = 0
        self.nChromosomes = nChromosomes

        if chromosomes is None:
            self.chromosomes = [None for i in range(nChromosomes)]
        else:
            self.chromosomes = chromosomes

        # Derive ebv from the chromosomes in both branches. Previously a State
        # built from an explicit list (e.g. by copy()) kept ebv == 0, so
        # score()/replaceChromosome() on a copy used the wrong total.
        self.setEbv()

    def copy(self):
        """Return a new State with its own list (the Sample objects are shared)."""
        return State(self.nChromosomes, self.chromosomes.copy())

    def setEbv(self):
        """Recompute ``ebv`` as the sum of ``ebv`` over non-None chromosomes."""
        self.ebv = 0
        for chrom in self.chromosomes:
            if chrom is not None:
                # Not in-place: chrom.ebv may be a view into a shared array.
                self.ebv = self.ebv + chrom.ebv

    @profile
    def replaceChromosome(self, index, sampler, phenotype, multiple):
        """Resample chromosome ``index`` conditioned on all other chromosomes.

        The residual phenotype (``phenotype`` minus the ebv contributed by
        every other chromosome) is passed with ``multiple`` to
        ``sampler.sample``. The drawn sample replaces the current one.
        """
        current = self.chromosomes[index]
        chr_ebv = current.ebv if current is not None else 0
        residualPhenotype = phenotype - (self.ebv - chr_ebv)
        self.chromosomes[index] = sampler.sample(residualPhenotype, multiple)
        self.setEbv()

    def score(self, phenotype):
        """Mean absolute difference between this state's ebv and ``phenotype``."""
        return np.mean( np.abs(self.ebv - phenotype))

class ChromosomeSampler(object):
    """A fixed pool of pre-drawn Sample objects for one chromosome segment.

    The predicted effect vector of every pooled sample is computed once at
    construction, so each call to ``sample`` only needs a distance computation
    and one weighted draw.
    """

    @profile
    def __init__(self, chromosome, haplotypes, qtl, nSamples, start, stop):
        pool = []
        for _ in range(nSamples):
            pool.append(Sample(chromosome, haplotypes, start, stop))
        self.samples = pool
        self.nSamples = len(pool)

        dosage_matrix = np.array([s.dosages for s in pool], dtype = np.float32)

        # One row per pooled sample, one column per row of ``qtl``.
        ebvs = np.dot(dosage_matrix, qtl.T)
        self.sample_ebvs = ebvs
        # Squared norm of each row, reused by mat_distance on every draw.
        self.sample_ebvs_sq = np.sum(ebvs**2, axis = 1)

        for row_index, s in enumerate(pool):
            s.ebv = ebvs[row_index, :]

    @profile
    def sample(self, phenotype, multiple):
        """Draw one pooled Sample.

        Each sample is chosen with probability proportional to
        ``exp(-multiple * ||ebv - phenotype||^2 / 2)``.
        """
        log_weights = mat_distance(self.sample_ebvs, self.sample_ebvs_sq, phenotype, multiple)
        chosen = getSample(expNorm_1D(log_weights))
        return self.samples[chosen]

@njit
def getSample(weights):
    """Draw an index from normalized ``weights`` by inverse-CDF sampling.

    Returns the first index whose cumulative weight exceeds a uniform draw.
    """
    u = np.random.random()
    running = 0
    n = len(weights)
    for idx in range(n):
        running += weights[idx]
        if running > u:
            return idx
    # Only reachable if round-off leaves the cumulative sum <= u.
    return n - 1

# See numba_distance_speed.py in pythontests for comparison of distance functions.
# This uses the "binomial trick" from https://blog.smola.org/post/969195661/in-praise-of-the-second-binomial-formula
@njit
def mat_distance(reference, ref_sq, target, multiple):
    """Scaled negative half squared distance from each reference row to target.

    Returns ``-(||t||^2 + ||r_i||^2 - 2 r_i . t) / 2 * multiple`` for every row
    ``r_i`` of ``reference``. ``ref_sq[i]`` must hold the precomputed ``||r_i||^2``.
    """
    t_sq = 0
    for value in target:
        t_sq += value ** 2

    cross = np.dot(reference, target)

    for row in range(reference.shape[0]):
        cross[row] = -(t_sq + ref_sq[row] - 2*cross[row])/2 * multiple

    return cross

@njit
def norm(vect):
    """Euclidean (L2) norm of ``vect``, computed with an explicit loop."""
    sq_sum = 0
    for x in vect:
        sq_sum += x**2
    return sq_sum**.5

class Sample(object):
    """One random draw of inherited dosages for loci ``start`` .. ``stop - 1``.

    ``chromosome`` is accepted for interface compatibility but is not stored.
    """

    @profile
    def __init__(self, chromosome, haplotypes, start, stop):
        span = stop - start
        self.nLoci = span

        self.start = start
        self.stop = stop

        # Per-locus sum of the alleles inherited from each parent (see sampleAssignement).
        self.dosages = sampleAssignement(haplotypes, span)
        # Effect vector for this draw; None until the owner assigns it.
        self.ebv = None


@njit
def sampleAssignement(haplotypes, nLoci):
    """Simulate one offspring's dosages by random recombination.

    For each parent ``j`` in (0, 1), a starting haplotype is picked with
    probability 1/2. The active haplotype then flips at positions spaced by
    geometric draws with success probability ``1/nLoci`` (mean gap of
    ``nLoci`` loci). The allele of the active haplotype at each locus is added
    to the output.

    haplotypes : indexed ``haplotypes[parent, haplotype, locus]``; presumably
        shape (2, 2, nLoci) -- TODO confirm against callers.
    nLoci : number of loci; ``nLoci == 0`` would divide by zero below.

    Returns an int8 array of length ``nLoci`` with the summed alleles from
    both parents.
    NOTE(review): ``dosages`` is int8 while callers may pass float
    haplotypes; non-integer allele values would be truncated here -- confirm
    inputs are 0/1.
    """
    rec_rate = 1/nLoci
    dosages = np.full(nLoci, 0, dtype = np.int8)
    for j in range(2):
        # Bool start state; flipped via 1 - currentHap and used as an index.
        currentHap = np.random.random() < .5
        # Use geometric distribution to get time until next switch.
        nextSwitch = np.random.geometric(rec_rate)

        for i in range(nLoci):
            if i == nextSwitch:
                # Crossover: switch to the other haplotype of parent j.
                currentHap = 1 - currentHap
                nextSwitch = i + np.random.geometric(rec_rate)
            dosages[i] += haplotypes[j, currentHap, i]
    return dosages


# @njit
# def sampleAssignement(haplotypes, nLoci):
#     rec_rate = 1/nLoci
#     dosages = np.full(nLoci, 0, dtype = np.int8)
#     for j in range(2):
#         currentHap = np.random.random() < .5
#         # Use geometric distribution to get time until next switch.
#         i = 0
#         while i < nLoci:
#             currentHap = 1 - currentHap
#             nextSwitch = i + np.random.geometric(rec_rate)
#             if nextSwitch > nLoci:
#                 nextSwitch = nLoci
#             while i < nextSwitch:           
#                 dosages[i] += haplotypes[j, currentHap, i]
#                 i += 1
#     return dosages


@njit
def getEbvFromAssignments(assignments, qtleffects):
    """Sum the per-locus QTL effects of the inherited haplotypes.

    ``assignments[parent, locus]`` selects which of that parent's two
    haplotypes was inherited (parent 0 = sire, 1 = dam).
    ``qtleffects`` is indexed ``[parent, haplotype, locus, qtl]``.
    Returns a float32 vector of length ``qtleffects.shape[3]``.
    """
    n_loci = qtleffects.shape[2]
    n_qtl = qtleffects.shape[3]
    total = np.zeros(n_qtl, dtype = np.float32)
    for parent in range(2):
        for locus in range(n_loci):
            effects = qtleffects[parent, assignments[parent, locus], locus, :]
            for k in range(n_qtl):
                total[k] += effects[k]
    return total

@profile
def simple_sampler(phenotype, chr_samplers, nIter = 5000, nThin = 100, nBurnIn = 1000):
    """Sample chromosome configurations that explain ``phenotype``.

    Every chromosome is first initialised once, in order, with ``multiple=1``.
    Then, for ``nIter`` iterations, a uniformly random chromosome is resampled
    given the others. ``multiple`` decreases linearly from
    ``InputOutput.args.multiple`` towards 0.
    NOTE(review): a smaller ``multiple`` flattens the sampling weights as the
    run progresses -- confirm this schedule is intended.

    Returns copies of the state taken at iterations ``i`` with
    ``i % nThin == 0`` and ``i > nBurnIn``; the score of each is printed.
    """
    nChromosomes = len(chr_samplers)
    kept = []

    state = State(nChromosomes, None)
    for chrom in range(nChromosomes):
        state.replaceChromosome(chrom, chr_samplers[chrom], phenotype, multiple = 1)

    # Pre-draw which chromosome to update at every iteration.
    updateOrder = np.random.randint(0, nChromosomes, size = nIter)
    for it, chrom in enumerate(updateOrder):
        state.replaceChromosome(chrom, chr_samplers[chrom], phenotype, multiple = (1 - it/nIter)*InputOutput.args.multiple)
        if it % nThin == 0 and it > nBurnIn:
            kept.append(state.copy())
            print(it, state.score(phenotype))

    return kept

def getEffects(haplotype, qtls):
    """Per-locus QTL effects carried by one haplotype.

    Parameters
    ----------
    haplotype : array whose first axis runs over loci (presumably 1-D alleles).
    qtls : 2-D array indexed ``[qtl, locus]``.

    Returns
    -------
    float32 array of shape ``(haplotype.shape[0], qtls.shape[0])`` whose row
    ``i`` is ``haplotype[i] * qtls[:, i]``.
    """
    effects = np.zeros((haplotype.shape[0], qtls.shape[0]), dtype = np.float32)
    for locus, allele in enumerate(haplotype):
        effects[locus, :] = allele * qtls[:, locus]
    return effects

@profile
def constructSamplers(haplotypes, qtls, chrMap) :
    """Build one ChromosomeSampler per chromosome label found in ``chrMap``.

    ``chrMap`` gives a chromosome label per locus. Labels are treated as
    1-based (1 .. max label); loci labelled 0 or below never get a sampler.
    Each chromosome is taken to be the span from its first to its last locus.
    NOTE(review): a chromosome whose loci are not contiguous would also cover
    the loci in between -- confirm chrMap is sorted by chromosome.

    Labels in 1 .. max(chrMap) that do not occur are skipped. Previously an
    all-False mask made ``argmax`` return 0 at both ends, so the missing
    chromosome silently spanned the whole genome.

    Returns the samplers in increasing label order.
    """
    # Need to handle map better.
    chrMap = np.array(chrMap)
    nChr = np.max(chrMap)
    samplers = []
    for val in range(nChr):
        # The loop index is 0-based; chromosome labels in chrMap are 1-based.
        mask = chrMap == (val + 1)
        if not mask.any():
            continue
        # First and one-past-last locus carrying this label.
        start = np.argmax(mask)
        stop = len(mask) - np.argmax(mask[::-1])

        print(val, np.sum(mask), start, stop)

        chr_samplers = ChromosomeSampler(val, haplotypes[:,:,start:stop], qtls[:,start:stop], InputOutput.args.samplesize, start, stop)
        samplers.append(chr_samplers)

    return samplers

@profile
def imputeIndividualFromPhenotypes(ind, sire, dam, qtls, chrMap):
    """Fill ``ind.dosages`` from posterior samples that explain ``ind.phenotype``.

    The parents' haplotypes are stacked into a float32 array indexed
    ``[parent, haplotype, locus]`` (sire = 0, dam = 1). One sampler is built
    per chromosome and ``simple_sampler`` is run. Then every posterior
    sample's dosages are passed to ``addDosage`` with weight
    ``1/len(post_states)``. Only ``ind.phenotype`` is read from ``ind``.
    Returns None; ``ind.dosages`` is overwritten.
    """
    nQtl, nLoci = qtls.shape

    haplotypes = np.full((2, 2, nLoci), 0, dtype = np.float32)
    for parent_idx, parent in enumerate((sire, dam)):
        for hap_idx in range(2):
            haplotypes[parent_idx, hap_idx, :] = parent.haplotypes[hap_idx]

    chr_samplers = constructSamplers(haplotypes, qtls, chrMap)

    post_states = simple_sampler(ind.phenotype, chr_samplers, nIter = 10000, nThin = 100, nBurnIn = 2000)
    ind.dosages = np.full(nLoci, 0, dtype = np.float32)

    # Guard keeps the weight from dividing by zero when nothing was kept.
    if post_states:
        weight = 1/len(post_states)
        for state in post_states:
            for sample in state.chromosomes:
                addDosage(weight, ind.dosages, sample.dosages, sample.start, sample.stop)
@njit
def addDosage(weight, dosages, sample_dosages, start, stop):
    """Add ``weight * sample_dosages`` into ``dosages[start:stop]`` in place.

    ``sample_dosages`` is indexed from 0 and covers ``stop - start`` loci.
    Previously ``weight`` was ignored, so repeated calls summed the samples
    instead of forming the weighted (e.g. mean) dosage.
    """
    for offset in range(stop - start):
        dosages[start + offset] += weight * sample_dosages[offset]


@njit()
def expNorm_1D(mat):
    """Return ``exp(mat)`` normalized to sum to 1 (a softmax over ``mat``).

    The maximum is subtracted before exponentiating so ``exp`` cannot
    overflow. Side effect: ``mat`` itself is shifted in place by that maximum.
    The previous version used ``maxVal == 1`` as an "unset" sentinel, so an
    element equal to 1 let the next, smaller element become the "max". That
    could overflow ``exp`` and produce NaN weights.
    """
    n = len(mat)
    if n == 0:
        return np.exp(mat)

    maxVal = mat[0]
    for a in range(1, n):
        if mat[a] > maxVal:
            maxVal = mat[a]
    for a in range(n):
        mat[a] -= maxVal

    tmp = np.exp(mat)
    total = 0.0
    for a in range(n):
        total += tmp[a]
    for a in range(n):
        tmp[a] /= total

    return tmp
